import contextlib
import dataclasses
import functools
import itertools
import logging
import operator
import re
from collections import namedtuple
from itertools import chain
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy
from sympy.printing.printer import Printer

import torch
import torch.fx
from torch.utils._sympy.value_ranges import ValueRanges

from .. import config, metrics
from ..utils import (
    DeferredLineBase,
    do_bench,
    free_symbol_startswith,
    IndentedBuffer,
    sympy_dot,
    sympy_subs,
    sympy_symbol,
    unique,
)
from ..virtualized import ops, OpsValue, V

if TYPE_CHECKING:
    from ..ir import TensorBox

# Artifact logger for the "schedule" artifact of this module; used by
# data_type_logger below.
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def data_type_logger(msg):
    """Emit ``msg`` on the schedule logger at DEBUG level.

    The message is prefixed with "Data type propagation: ". Nothing is
    logged when DEBUG is not enabled for ``schedule_log``.
    """
    if not schedule_log.isEnabledFor(logging.DEBUG):
        return
    schedule_log.debug("Data type propagation: %s", msg)


# Record describing a tensor kernel argument; built in KernelArgs.python_argdefs
# as TensorArg(inner_name, outer_name, dtype, True).
TensorArg = namedtuple("TensorArg", ["name", "buffer", "dtype", "check_alignment"])
# Record describing a size-variable kernel argument; built in
# KernelArgs.python_argdefs as SizeArg(inner_name, outer_expr).
SizeArg = namedtuple("SizeArg", ["name", "expr"])

# Pair of classes registered for one device (see register_backend_for_device).
DeviceCodegen = namedtuple("DeviceCodegen", ["scheduling", "wrapper_codegen"])
# Registry mapping a device name to its DeviceCodegen entry.
device_codegens: Dict[str, DeviceCodegen] = {}


# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
#
# Kernel code generation is determined by different Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str, device_scheduling: type, device_wrapper_codegen: type
):
    """Register the scheduling and wrapper-codegen classes for ``device``.

    Stores a ``DeviceCodegen`` entry in ``device_codegens`` under ``device``,
    replacing any entry already present for that key.
    """
    entry = DeviceCodegen(
        scheduling=device_scheduling,
        wrapper_codegen=device_wrapper_codegen,
    )
    device_codegens[device] = entry


def get_scheduling_for_device(device: str):
    """Return the scheduling class registered for ``device``, or None if unregistered."""
    if device not in device_codegens:
        return None
    return device_codegens[device].scheduling


def get_wrapper_codegen_for_device(device: str):
    """Return the wrapper-codegen class registered for ``device``, or None if unregistered."""
    if device not in device_codegens:
        return None
    return device_codegens[device].wrapper_codegen


def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Return ``index`` as a new list with one extra contiguous index appended.

    The extra entry is the dot product of ``index_vars`` with the contiguous
    strides for ``sizes``; the added contiguous index prevents reordering.
    """
    from ..ir import FlexibleLayout

    extended = list(index)
    contiguous_strides = FlexibleLayout.contiguous_strides(sizes)
    extended.append(sympy_dot(index_vars, contiguous_strides))
    return extended


def get_device_op_overrides(device: str):
    """Return a device-op-overrides instance for ``device``.

    ``"cuda"`` yields a ``CUDADeviceOpOverrides`` (imported lazily); any other
    device string yields a plain ``DeviceOpOverrides``.
    """
    assert isinstance(device, str)
    if device != "cuda":
        return DeviceOpOverrides()

    from .cuda.device_op_overrides import CUDADeviceOpOverrides

    return CUDADeviceOpOverrides()


@functools.lru_cache(None)
def boolean_ops():
    """Return the tuple of op names that DataTypePropagation maps to ``torch.bool``.

    The result is cached, so every call returns the same tuple object.
    """
    unary_and_bitwise = ("is_inf", "is_nan", "bitwise_xor", "logical_not", "signbit")
    comparisons = ("le", "lt", "ge", "gt", "eq", "ne")
    return unary_and_bitwise + comparisons


DTYPE_TO_COMPUTATION_DTYPE = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}


class DataTypePropagation:
    """Annotate the FX nodes of a loop body with a deduced dtype.

    The deduced dtype of each node is stored on the ``OptimizationContext``
    kept in ``node.meta[OptimizationContext.key]``.
    """

    def __init__(self, body) -> None:
        """Collect the root graph and all subblock graphs of ``body``."""
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Promote the dtypes of ``node``'s non-placeholder inputs.

        Returns None when there are no such inputs, or when any of them has
        no dtype recorded yet.
        """
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Propagate through the subgraph named by ``node.target`` and return its dtype."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Return the dtype deduced for ``node``, or None when it cannot be deduced.

        The checks are ordered; the first matching rule wins.
        """
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # we can infer output node if it only have 1 arg
            if len(node.args) != 1:
                return None

        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            # Both ops carry their dtype as the last argument.
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        # "index_expr" was previously listed here as well, but it can never
        # reach this point: it is already handled by the args[-1] rule above.
        if node.target == "get_index":
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Deduce and record a dtype for every node of ``graph``.

        Returns the dtype recorded for the graph's "output" node, or None if
        no node with that target was seen.
        """
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        """Run propagation on the root graph."""
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        """Build a propagator for ``body`` and run it."""
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Run propagation on the ``LoopBody`` held by a ``SchedulerNode``."""
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)


class ExprPrinter(Printer):
    """Sympy printer base with operator printing shared by its subclasses."""

    @staticmethod
    def paren(string):
        """Wrap ``string`` in parentheses unless it is already atomic.

        CSE variables, plain identifier/number tokens, simple ``(...)``
        groups, the empty string and fully parenthesized strings are returned
        unchanged.
        """

        def fully_wrapped(text):
            # True only if the opening paren at position 0 is closed by the
            # very last character.
            if text[0] != "(" or len(text) < 2:
                return False
            depth = 1
            last_pos = len(text) - 2
            for pos, ch in enumerate(text[1:]):
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                if depth == 0 and pos != last_pos:
                    return False
            assert depth == 0
            return True

        needs_no_wrapping = (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        )
        if needs_no_wrapping:
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if fully_wrapped(string):
            return string
        return f"({string})"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    def _print_Relational(self, expr):
        separator = f" {expr.rel_op} "
        return separator.join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Mul(self, expr):
        return "*".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Add(self, expr):
        return " + ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Mod(self, expr):
        return " % ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_FloorDiv(self, expr):
        # Left to subclasses; the base printer has no floor-division syntax.
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_CleanDiv(self, expr):
        # CleanDiv is printed exactly like FloorDiv.
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # GreaterThan:          >=
        # StrictlyGreaterThan:  >
        # Go figure...
        return " >= ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_align(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"align({operand})"


class PythonPrinter(ExprPrinter):
    """Print sympy expressions as Python source (``//``, ``%``, ``math.*``, builtins)."""

    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        x_str = self.paren(self.doprint(x))
        div_str = self.paren(self.doprint(div))
        mod_str = self.paren(self.doprint(mod))
        # Dividing by 1 is a no-op, so the floor division is skipped.
        if div_str != "1":
            x_str = f"({x_str} // {div_str})"
        return f"{x_str} % {mod_str}"

    def _print_FloorDiv(self, expr):
        x, div = expr.args
        numerator = self.paren(self.doprint(x))
        denominator = self.paren(self.doprint(div))
        return f"({numerator} // {denominator})"

    def _helper_sqrt(self, expr):
        radicand = self._print(expr)
        return f"math.sqrt({radicand})"

    def _print_Pow(self, expr):
        # Pow() confuses triton
        base, exp = expr.args
        # NB: Remember this is sizevar computation!  You don't typically
        # expect to have to do floating point computation including exponents
        # in sizevar compute.  Instead of adding support for floating
        # point pow, you should make upstream retranslate the Sympy expression
        # into Tensor expressions earlier and do that instead.
        if exp == 0.5:
            return self._helper_sqrt(base)
        if exp == -0.5:
            return "1/" + self._helper_sqrt(base)
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        if exp == 0:
            return "1"
        # Integer powers are expanded into repeated multiplication.
        repeated = "*".join([self.paren(base)] * abs(exp))
        if exp > 0:
            return repeated
        return "1/" + self.paren(repeated)

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.floor({operand})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.ceil({operand})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"abs({operand})"

    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        operands = ", ".join(self._print(arg) for arg in expr.args)
        return f"max({operands})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        operands = ", ".join(self._print(arg) for arg in expr.args)
        return f"min({operands})"

    def _print_cos(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.cos({operand})"

    def _print_cosh(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.cosh({operand})"

    def _print_acos(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.acos({operand})"

    def _print_sin(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.sin({operand})"

    def _print_sinh(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.sinh({operand})"

    def _print_asin(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.asin({operand})"

    def _print_tan(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.tan({operand})"

    def _print_tanh(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.tanh({operand})"

    def _print_atan(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"math.atan({operand})"

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        operand = self._print(expr.args[0])
        return f"round({operand})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        printed_number = self._print(number)
        return f"round({printed_number}, {ndigits})"


class OpOverrides:
    """Default string-based op implementations layered over a parent handler.

    Attributes not defined here are looked up on the wrapped ``parent``.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails: delegate to parent.
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        # ``dtype`` is not used here; the constant is printed via repr().
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv("1", x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        operand = ExprPrinter.paren(x)
        return f"~{operand}"

    @staticmethod
    def logical_not(a):
        operand = ExprPrinter.paren(a)
        return f"{operand} == 0"

    @staticmethod
    def bitwise_and(x, y):
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} & {rhs}"

    @staticmethod
    def bitwise_or(x, y):
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} | {rhs}"

    @staticmethod
    def bitwise_xor(x, y):
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} ^ {rhs}"

    @staticmethod
    def bitwise_left_shift(x, y):
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} << {rhs}"

    # TODO(fdrocha): this is currently not being used anywhere,
    # pending on moving triton pin past 972b761
    @staticmethod
    def bitwise_right_shift(x, y):
        lhs = ExprPrinter.paren(x)
        rhs = ExprPrinter.paren(y)
        return f"{lhs} >> {rhs}"

    @staticmethod
    def remainder(a, b):
        # When the raw mod result is nonzero and its sign differs from the
        # divisor's, the divisor is added to it.
        r = ops.mod(a, b)
        needs_fixup = f"(({r} != 0) & (({r} < 0) != ({b} < 0)))"
        return ops.where(needs_fixup, ops.add(r, b), r)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))


class DeviceOpOverrides:
    """Base class for device-specific overrides; every hook raises NotImplementedError."""

    def import_get_raw_stream_as(self, name):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def set_device(self, device_idx):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def synchronize(self):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def device_guard(self, device_idx):
        """Not implemented in the base class."""
        raise NotImplementedError()


class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        """Return the line, or None if ``self.name`` is in any removal set."""
        removal_sets = (
            V.graph.removed_buffers,
            V.kernel.removed_buffers,
            V.graph.inplaced_to_remove,
            V.kernel.inplaced_to_remove,
        )
        for removed in removal_sets:
            if self.name in removed:
                return None
        return self.line

    def _new_line(self, line):
        # Keep the same buffer name for the replacement line.
        return DeferredLine(self.name, line)


class BracesBuffer(IndentedBuffer):
    """IndentedBuffer whose indentation levels are delimited by braces."""

    def indent(self, offset=1):
        """Return a context manager that changes the indent by ``offset`` levels.

        A positive ``offset`` writes one "{" per level on entry and one "}"
        per level on exit. A negative ``offset`` does the reverse: it writes
        "}" on entry and "{" on exit.

        The exit actions run in a ``finally`` block, so the indent level and
        brace balance are restored even if the ``with`` body raises.
        """

        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            try:
                yield
            finally:
                for _ in range(-offset):
                    self.writeline("{")
                    self._indent += 1
                for _ in range(offset):
                    self._indent -= 1
                    self.writeline("}")

        return ctx()


class InplacedBuffer(NamedTuple):
    """Record for a buffer updated in place (created in KernelArgs.make_inplace)."""

    # Kernel-side argument name, e.g. "in_out_ptr0".
    inner_name: str
    # Outer buffer names sharing this argument; make_inplace appends to it.
    other_names: List[str]


class KernelArgs:
    @staticmethod
    def _lookup(prefix, odict, name):
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        self.input_buffers = dict()
        self.output_buffers = dict()
        self.inplace_buffers = dict()
        self.sizevars = sizevars or dict()

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def seed_offset(self, name, value):
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        return f"c_void_p({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        return f"c_long({size})"

    def cpp_argdefs(self):
        from .cpp import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        arg_defs = []
        call_args = []
        precompile_args: List[Union[TensorArg, SizeArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            precompile_args.append(
                TensorArg(
                    inplaced.inner_name,
                    inplaced.other_names[-1],
                    V.graph.get_dtype(inplaced.other_names[-1]),
                    True,
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(
                TensorArg(inner, outer, V.graph.get_dtype(outer), True)
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)

        return arg_defs, call_args, precompile_args

    def aliases(self):
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = set()
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs


class CSEVariable:
    """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis.
    To do so, the backends can simply overload `Kernel.create_cse_var`
    The "CSEVariable.update_on_args" method gives you a hook for annotations
    See example of TritonCSEVariable in triton.py
    """

    def __init__(self, name, bounds: ValueRanges):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        pass


class CppWrapperKernelArgs(KernelArgs):
    """KernelArgs variant that formats call arguments as C++ expressions."""

    def wrap_ptr_arg(self, buf, dtype):
        from .cpp import DTYPE_TO_CPP

        if not config.aot_inductor.abi_compatible:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"
        # In the abi_compatible model, we just return the buf here.
        # We will form correct call args later in wrapper.generate_kernel_all.
        return buf

    def wrap_size_arg(self, size):
        # Sizes are passed through as their plain string form.
        return f"{size}"


class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        # Formatting of emitted assignments.
        self.prefix = prefix
        self.suffix = suffix
        self.name_prefix = name_prefix
        # Expression text -> variable already holding that expression.
        self.cache = {}
        # Falsy arguments (including empty containers) get fresh objects.
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        self.varname_map = varname_map or {}
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = set()

    def invalidate(self, keep_vars: Set[str]):
        """Drop every cached entry whose variable is not in ``keep_vars``."""
        stale_stores = [
            name for name, tmp in self.store_cache.items() if tmp not in keep_vars
        ]
        for name in stale_stores:
            del self.store_cache[name]
            self.invalidated_stores.add(name)
        self.cache = {
            key: var for key, var in self.cache.items() if var in keep_vars
        }

    def clone(self):
        """Return a new CSE with the same settings and an empty expression cache."""
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        settings = dict(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )
        return CSE(**settings)

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        """Return the variable holding ``expr``, emitting code into ``buffer`` on first use.

        An expression seen before is served from the cache and only has its
        bounds tightened. With ``assignment=False`` the expression is written
        as a bare statement and None is cached and returned.
        """
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            return expr

        is_multiline = isinstance(expr, IndentedBuffer)
        cache_key = expr.getvalue() if is_multiline else expr
        var = self.cache.get(cache_key, None)
        if var:
            # Cache hit: only refine what is known about the value range.
            var.bounds = var.bounds.tighten(bounds)
            return var

        var = self.newvar(bounds) if assignment else None
        self.cache[cache_key] = var
        if not write:
            return var

        if V.kernel.current_node:
            V.kernel.current_node.codegen_originating_info(buffer, only_once=True)
        if is_multiline:
            if assignment:
                buffer.writeline(f"{self.prefix}{var} =")
            buffer.splice(expr)
            buffer.writeline(self.suffix)
        elif assignment:
            buffer.writeline(f"{self.prefix}{var} = {expr}{self.suffix}")
        else:
            buffer.writeline(f"{expr}{self.suffix}")

        return var

    def newvar(self, bounds: ValueRanges = ValueRanges.unknown()) -> CSEVariable:
        """Create a fresh variable named ``<name_prefix><counter>`` and record it."""
        fresh_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        fresh_var = V.kernel.create_cse_var(fresh_name, bounds)
        self.varname_map[fresh_name] = fresh_var
        return fresh_var


class IndirectAssertLine(DeferredLineBase):
    """Deferred bounds-check line for an indirect index variable.

    The check is only materialised when called, using the bounds known for
    the variable at that time.
    """

    def __init__(self, line, assert_fn, var, mask, size_map):
        self.line = line
        self.assert_fn = assert_fn
        self.var = var
        self.mask = mask
        self.size_map = size_map

    def __call__(self):
        """Return the formatted assert line, or None if both bounds are proven."""
        size, size_str = self.size_map[(self.var, self.mask)]

        # We assert if we've not been able to prove the bound
        var_bounds = self.var.bounds
        assert_min = (var_bounds.lower >= 0) != sympy.true
        assert_max = (var_bounds.upper < size) != sympy.true

        if assert_min and assert_max:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is suported by triton
            cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
            cond_print = f"0 <= {self.var} < {size_str}"
        elif assert_min:
            cond = f"0 <= {self.var}"
            cond_print = cond
        elif assert_max:
            cond = f"{self.var} < {size_str}"
            cond_print = cond
        else:
            # Both bounds are statically known to hold: no assert needed.
            return None

        if self.mask:
            # The condition is or-ed with the negated mask.
            cond = f"({cond}) | ~{self.mask}"
        return self.line.format(
            assert_fn=self.assert_fn, cond=cond, cond_print=cond_print
        )

    def _new_line(self, line):
        return IndirectAssertLine(
            line, self.assert_fn, self.var, self.mask, self.size_map
        )


class CodeGen:
    """Context-manager base class that owns a ``contextlib.ExitStack``."""

    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        """Enter the owned exit stack and return ``self``."""
        stack = self.exit_stack
        stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Unwind the owned exit stack; the stack's return value is not forwarded."""
        stack = self.exit_stack
        stack.__exit__(exc_type, exc_val, exc_tb)


class Kernel(CodeGen):
    """
    Base class for kernel code generation.

    A kernel owns three code buffers (``loads``, ``compute``, ``stores``) and a
    CSE instance. While the kernel is entered as a context manager, the active
    ops handler is replaced by a proxy that routes every op through the CSE
    and forwards load/store/reduction-style ops to the methods below, which
    subclasses are expected to implement (the defaults raise
    ``NotImplementedError``).
    """

    # Passed to the CSE constructor in __init__.
    newvar_prefix = ""
    suffix = ""
    # Called in __enter__ with the current ops handler to build the handler
    # that actually renders each op; must be set before entering the kernel.
    overrides = None
    load_format = None
    store_format = None

    def __init__(self, args=None, increase_kernel_count=True):
        """
        Args:
            args: kernel argument container; any falsy value (not only None)
                is replaced by a fresh ``KernelArgs()``.
            increase_kernel_count: when true, bump
                ``metrics.generated_kernel_count`` by one.
        """
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()
        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = set()
        self.store_buffer_names = set()
        self._load_mask = None
        # set in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges]] = None
        # Upper bounds for indirect_indexing and their str representation.
        # Entries are written by the ops proxy's indirect_indexing, keyed by
        # (var, mask), and read back when the deferred assertion is rendered.
        self.indirect_max_sizes: Dict[Tuple[str, str], Tuple[sympy.Expr, str]] = {}

        self.removed_buffers = set()
        self.inplaced_to_remove = set()

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        #   the buffer specified by key
        self.inplace_update_buffers = dict()
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1

    @contextlib.contextmanager
    def set_current_node(self, node):
        """
        Make ``node`` the current node for the duration of the ``with`` body.

        Also replaces ``self.node_to_bounds`` with the bounds computed from
        ``node._body``. On exit only ``current_node`` is restored;
        ``node_to_bounds`` keeps the value set here.
        """
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        """
        Temporarily redirect code emission to the given buffers.

        Args:
            lb: buffer to use as ``self.loads``.
            cb: buffer to use as ``self.compute``; defaults to ``lb``.
            sb: buffer to use as ``self.stores``; if omitted, ``self.stores``
                is None inside the ``with`` body.

        The CSE is replaced by a clone for the duration of the body. The
        original buffers and the original CSE object are restored on exit,
        including when the body raises.
        """
        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = cse.clone()
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr):
        """Emit a load of buffer ``name`` at ``index``; implemented by subclasses."""
        raise NotImplementedError()

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name, index, value):
        """Emit a store of a reduction result; implemented by subclasses."""
        raise NotImplementedError()

    def store(self, name, index, value, mode=None):
        """Emit a store of ``value`` to buffer ``name``; implemented by subclasses."""
        raise NotImplementedError()

    def reduction(self, dtype, src_dtype, reduction_type, value):
        """Emit a reduction over ``value``; implemented by subclasses."""
        raise NotImplementedError()

    def scan(self, dtype, combine_fn, value, init):
        """Emit a scan over ``value``; implemented by subclasses."""
        raise NotImplementedError()

    def bucketize(
        self,
        values,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ):
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError()

    @property
    def assert_function(self) -> str:
        """Name of the assertion function written into generated bounds checks."""
        raise NotImplementedError()

    def index_to_str(self, index: sympy.Expr) -> str:
        """Render an index expression as source text; implemented by subclasses."""
        raise NotImplementedError()

    def __enter__(self):
        """
        Enter the kernel: install the CSE proxy as the active ops handler and
        this kernel as the active kernel handler. Both are registered on
        ``self.exit_stack`` so they are undone when the kernel is exited.
        """

        # All methods of CSEProxy are static and use `self` / `parent_handler`
        # from the enclosing __enter__ scope, i.e. `self` below is always the
        # Kernel instance.
        class CSEProxy:
            # NOTE(review): `self` in this class body is the enclosing Kernel
            # instance, so this assigns `name` on the kernel rather than
            # defining an attribute on CSEProxy - confirm this is intended.
            self.name = "CSEProxy"

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                # Fallback for every op not defined explicitly on the proxy:
                # render it with parent_handler and run the result through CSE.
                def inner(*args, **kwargs):
                    # TritonTemplateKernel has no current_node
                    # (bounds stay unknown when the interpreter has no current_node)
                    buf_bounds = ValueRanges.unknown()
                    if hasattr(V.interpreter, "current_node"):
                        fx_node = V.interpreter.current_node
                        assert isinstance(self.node_to_bounds, dict)
                        buf_bounds = self.node_to_bounds.get(
                            fx_node, ValueRanges.unknown()
                        )

                    csevar = self.cse.generate(
                        self.compute,
                        getattr(parent_handler, name)(*args, **kwargs),  # type: ignore[has-type]
                        bounds=buf_bounds,
                    )
                    csevar.update_on_args(name, args, kwargs)
                    return csevar

                return inner

            @staticmethod
            def indirect_indexing(var, size, check=True):
                """
                Turn ``var`` into a symbol usable in index expressions.

                If the variable's lower bound is negative, emits code that
                adds ``size`` to negative values first. When assertions are
                enabled (see ``generate_assert``), queues a deferred bounds
                check against ``size`` in the compute buffer.
                """
                # Skip CSE since this doesn't return an expression

                if var.bounds.lower < 0:
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg = var.bounds & ValueRanges(-sympy.oo, -1)
                        new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:
                            pos = var.bounds & ValueRanges(0, sympy.oo)
                            new_bounds = new_bounds | pos

                    # var + size
                    stm = ops.add(var, self.rename_indexing(size))
                    # Mixed negative and non-negative
                    if var.bounds.upper >= 0:
                        # where(var < 0, var + size, var)
                        lt = ops.lt(var, "0")
                        stm = ops.where(lt, stm, var)
                    new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                    new_var.update_on_args("index_wrap", (var,), {})
                    var = new_var

                if self.generate_assert(check):
                    mask = self.load_mask(var)

                    # An assertion line may have been written already, if so just
                    # update the max size.
                    map_key = (var, mask)
                    existing_size, _ = self.indirect_max_sizes.get(
                        map_key, (None, None)
                    )
                    if existing_size is not None:
                        size = sympy.Min(size, existing_size)
                    else:
                        line = (
                            '{assert_fn}({cond}, "index out of bounds: {cond_print}")'
                        )
                        self.compute.writeline(
                            IndirectAssertLine(
                                line,
                                self.assert_function,
                                var,
                                mask,
                                self.indirect_max_sizes,
                            )
                        )

                    # The deferred line looks the size up in this map when it
                    # is rendered, so a later update here still takes effect.
                    self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
                return sympy_symbol(str(var))

            @staticmethod
            def load(name: str, index: sympy.Expr):
                """
                Load ``name`` at ``index``, reusing a value cached by an
                earlier store to the same buffer when the index is not
                indirect.
                """
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                # An index containing a symbol prefixed "tmp" is routed to
                # indirect_load and bypasses the store cache.
                if free_symbol_startswith(index, "tmp"):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    return store_cache[name]
                return self.load(name, index)

            @staticmethod
            def store(name, index, value, mode=None):
                """
                Record the store and emit it unless the buffer was removed.

                With ``mode=None`` the stored value is cached for ``name`` and
                for every name in ``current_node.get_mutations()``. Returns
                None without emitting anything when ``name`` is in
                ``V.graph.removed_buffers``.
                """
                self.store_buffer_names.add(name)
                if mode is None:
                    self.cse.store_cache[name] = value
                    if self.current_node:
                        for other_name in self.current_node.get_mutations():
                            self.cse.store_cache[other_name] = value
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)

            @staticmethod
            def store_reduction(name, index, value):
                """
                Same bookkeeping as ``store`` with ``mode=None``, but the
                store is emitted through the kernel's ``store_reduction``.
                """
                self.store_buffer_names.add(name)
                self.cse.store_cache[name] = value
                if self.current_node:
                    for other_name in self.current_node.get_mutations():
                        self.cse.store_cache[other_name] = value

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(dtype, src_dtype, reduction_type, value):
                """Forward directly to the kernel's ``reduction``."""
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(dtype, combine_fn, value, init):
                """Forward directly to the kernel's ``scan``."""
                return self.scan(dtype, combine_fn, value, init)

            @staticmethod
            def bucketize(
                values,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ):
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =        [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True,  bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        super().__enter__()
        assert self.overrides
        # Wrap the handler that is active right now; the CSEProxy installed on
        # the next line delegates op rendering to it.
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Ask the scheduler to remove kernel-local buffers, then unwind the
        handlers installed in ``__enter__``.

        Note that V.graph.scheduler can be None when generating code for
        triton template kernels, in which case the removal step is skipped.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def generate_assert(self, check):
        """
        Decide whether an indirect-indexing bounds assertion should be emitted.

        True only when ``config.assert_indirect_indexing`` is set and either
        ``check`` or ``config.debug_index_asserts`` is truthy.
        """
        return (check or config.debug_index_asserts) and config.assert_indirect_indexing

    def load_mask(self, var):
        """Return the mask expression for ``var``; the base kernel has none."""
        # only the triton kernel requires mask
        return ""

    def rename_indexing(self, index) -> sympy.Expr:
        """
        Replace size symbols in ``index`` with kernel argument names.

        Symbols whose name starts with "s", "ps", or "i" (but not "idx") are
        registered through ``self.args.size`` and substituted by the result.
        A list or tuple is processed element-wise and returned as a list,
        despite the annotated return type.
        """
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]
        index = V.graph.sizevars.simplify(index)
        # Sorted by name so that self.args.size is called in a deterministic order.
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if x.name.startswith("s")
            or x.name.startswith("ps")
            or (x.name.startswith("i") and not x.name.startswith("idx"))
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        """Construct a ``CSEVariable`` from the given arguments."""
        return CSEVariable(*args, **kwargs)


@dataclasses.dataclass
class OptimizationContext:
    """Bundle of per-operation optimization hints (flags, dtype, op name)."""

    # Class-level constant, not a dataclass field.
    # NOTE(review): presumably the key this context is stored under - confirm
    # against the code that attaches it.
    key: ClassVar[str] = "opt_ctx"

    # Load value as mask
    is_load_as_mask: bool = False

    # Data type associated with the operation; None when not set.
    dtype: Optional[torch.dtype] = None
    # Name of the operation; empty string when not set.
    ops_name: str = ""

    # Load uint8 value as float32
    is_load_uint8_as_float: bool = False


@functools.lru_cache(None)
def jinja2_env():
    """
    Return a shared jinja2 ``Environment`` with strict undefined handling.

    Returns None when an ImportError is raised while importing jinja2 or
    building the environment. The result (including None) is cached, so the
    import is attempted only once per process.
    """
    with contextlib.suppress(ImportError):
        import jinja2

        return jinja2.Environment(undefined=jinja2.StrictUndefined)
    return None


class ChoiceCaller:
    """
    One candidate implementation considered during autotuning
    (see autotune_process.py).

    The autotuner first calls ``benchmark()`` to time the candidate; if the
    candidate wins, ``output_node()`` is called to obtain its output node.

    Children classes: TritonTemplateCaller, CUDATemplateCaller.
    """

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.input_nodes = input_nodes
        self.layout = layout

    def benchmark(self, *args, out) -> float:
        """Time one invocation of this choice on ``args``, writing to ``out``."""
        kernel_fn = self.to_callable()

        def run_once():
            return kernel_fn(*args, out=out)

        return do_bench(run_once)

    def call_name(self) -> str:
        """Implemented by subclasses."""
        raise NotImplementedError

    def to_callable(self):
        """Return the callable that ``benchmark`` times; implemented by subclasses."""
        raise NotImplementedError

    def hash_key(self) -> str:
        """Implemented by subclasses."""
        raise NotImplementedError

    def output_node(self) -> "TensorBox":
        """Return the output node for this choice; implemented by subclasses."""
        raise NotImplementedError


class KernelTemplate:
    """
    Common base for kernel templates that can produce autotuning choices.

    Children classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def _template_from_string(source):
        """Compile ``source`` with the shared jinja2 environment, or return
        None when no environment is available."""
        env = jinja2_env()
        return None if env is None else env.from_string(source)

    @staticmethod
    def _fake_get_dtype(fake_out):
        """
        Build a dtype lookup that answers for ``fake_out`` by name and defers
        to the graph's ``get_dtype`` (captured when this is called) otherwise.
        """
        real_get_dtype = V.graph.get_dtype

        def get_dtype(name):
            return (
                fake_out.get_dtype()
                if name == fake_out.get_name()
                else real_get_dtype(name)
            )

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Try to generate a new ChoiceCaller and add it to ``choices``.

        choices: A list of ChoiceCallers, extended in place.
        kwargs: Forwarded to self.generate() to build the new ChoiceCaller.

        Nothing is appended when a NotImplementedError is raised.
        """
        with contextlib.suppress(NotImplementedError):
            choices.append(self.generate(**kwargs))

    def generate(self, **kwargs) -> ChoiceCaller:
        """
        Build a ChoiceCaller from the given arguments; implemented by
        subclasses.
        """
        raise NotImplementedError
